{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vY4SK0xKAJgm"
   },
   "source": [
    "# Model Zoo -- Simple RNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sc6xejhY-NzZ"
   },
   "source": [
    "Demo of a simple RNN for sentiment classification (here: a binary classification problem with two labels, positive and negative). Note that a simple RNN usually doesn't work very well on long sequences because of the vanishing and exploding gradient problems. Also, this implementation uses padding to handle variable-length inputs: the shorter a sentence, the more `<pad>` placeholders are added to match the length of the longest sentence in its batch.\n",
    "\n",
    "Note that this RNN trains about 4 times slower than the equivalent model with packed sequences, [./rnn-simple-packed-imdb.ipynb](./rnn-simple-packed-imdb.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "moNmVfuvnImW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.1\n",
      "IPython 7.4.0\n",
      "\n",
      "torch 1.0.1.post2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchtext import data\n",
    "from torchtext import datasets\n",
    "import time\n",
    "import random\n",
    "\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GSRL42Qgy8I8"
   },
   "source": [
    "## General Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OvW1RgfepCBq"
   },
   "outputs": [],
   "source": [
    "RANDOM_SEED = 123\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "VOCABULARY_SIZE = 20000\n",
    "LEARNING_RATE = 1e-4\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 15\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "EMBEDDING_DIM = 128\n",
    "HIDDEN_DIM = 256\n",
    "OUTPUT_DIM = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mQMmKUEisW4W"
   },
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4GnH64XvsV8n"
   },
   "source": [
    "Load the IMDB Movie Review dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "WZ_4jiHVnMxN",
    "outputId": "7a3115ba-e294-46d4-aeb0-a8627b027f98"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "downloading aclImdb_v1.tar.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:03<00:00, 22.3MB/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Train: 20000\n",
      "Num Valid: 5000\n",
      "Num Test: 25000\n"
     ]
    }
   ],
   "source": [
    "TEXT = data.Field(tokenize = 'spacy')\n",
    "LABEL = data.LabelField(dtype = torch.float)\n",
    "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
    "\n",
    "# random.seed() returns None, so seed first and pass the resulting state explicitly\n",
    "random.seed(RANDOM_SEED)\n",
    "train_data, valid_data = train_data.split(random_state=random.getstate(),\n",
    "                                          split_ratio=0.8)\n",
    "\n",
    "print(f'Num Train: {len(train_data)}')\n",
    "print(f'Num Valid: {len(valid_data)}')\n",
    "print(f'Num Test: {len(test_data)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L-TBwKWPslPa"
   },
   "source": [
    "Build the vocabulary from the `VOCABULARY_SIZE` most frequent words in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "e8uNrjdtn4A8",
    "outputId": "2b653c07-da9f-4593-8b48-5571daf0e661"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size: 20002\n",
      "Number of classes: 2\n"
     ]
    }
   ],
   "source": [
    "TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "print(f'Vocabulary size: {len(TEXT.vocab)}')\n",
    "print(f'Number of classes: {len(LABEL.vocab)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JpEMNInXtZsb"
   },
   "source": [
    "`TEXT.vocab` contains the word counts and indices. The vocabulary has `VOCABULARY_SIZE + 2` entries because it also contains two special tokens: `<unk>` for unknown words and `<pad>` for padding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eIQ_zfKLwjKm"
   },
   "source": [
    "Make dataset iterators:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i7JiHR1stHNF"
   },
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size=BATCH_SIZE,\n",
    "    device=DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R0pT_dMRvicQ"
   },
   "source": [
    "Testing the iterators (note that the number of rows depends on the longest document in the respective batch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "y8SP_FccutT0",
    "outputId": "53624f67-6649-4bd6-8af3-95b0529c43f7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train\n",
      "Text matrix size: torch.Size([880, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Valid:\n",
      "Text matrix size: torch.Size([61, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Test:\n",
      "Text matrix size: torch.Size([42, 128])\n",
      "Target vector size: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "print('Train')\n",
    "for batch in train_loader:\n",
    "    print(f'Text matrix size: {batch.text.size()}')\n",
    "    print(f'Target vector size: {batch.label.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nValid:')\n",
    "for batch in valid_loader:\n",
    "    print(f'Text matrix size: {batch.text.size()}')\n",
    "    print(f'Target vector size: {batch.label.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nTest:')\n",
    "for batch in test_loader:\n",
    "    print(f'Text matrix size: {batch.text.size()}')\n",
    "    print(f'Target vector size: {batch.label.size()}')\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G_grdW3pxCzz"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nQIUm5EjxFNa"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(input_dim, embedding_dim)\n",
    "        self.rnn = nn.RNN(embedding_dim, hidden_dim)\n",
    "        self.fc = nn.Linear(hidden_dim, output_dim)\n",
    "        \n",
    "    def forward(self, text):\n",
    "\n",
    "        # [sentence len, batch size] => [sentence len, batch size, embedding size]\n",
    "        embedded = self.embedding(text)\n",
    "        \n",
    "        # [sentence len, batch size, embedding size] =>\n",
    "        #   output: [sentence len, batch size, hidden size] (all time steps)\n",
    "        #   hidden: [1, batch size, hidden size] (last time step only)\n",
    "        output, hidden = self.rnn(embedded)\n",
    "        \n",
    "        # [1, batch size, hidden size] => [batch size] (one logit per example)\n",
    "        return self.fc(hidden.squeeze(0)).view(-1)"
   ]
  },
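  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what `nn.RNN` computes with its default settings (one layer, `tanh` nonlinearity), each time step updates the hidden state as shown below. The symbols are notation introduced here for illustration and are not defined anywhere in this notebook: $x_t$ is the embedded token at step $t$, $h_t$ is the hidden state (with $h_0$ defaulting to zeros), $T$ is the padded sequence length, and $W_{fc}$, $b_{fc}$ stand for the parameters of `self.fc`.\n",
    "\n",
    "$$h_t = \\tanh\\left(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}\\right), \\qquad t = 1, \\dots, T$$\n",
    "\n",
    "$$\\text{logit} = W_{fc}\\, h_T + b_{fc}$$\n",
    "\n",
    "Because only the final state $h_T$ is passed to `self.fc`, every `<pad>` step appended to a shorter review also goes through this recurrence before the prediction is made."
   ]
  },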
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ik3NF3faxFmZ"
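   },
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "# Note: EMBEDDING_DIM and HIDDEN_DIM below override the values\n",
    "# from General Settings (128 and 256)\n",
    "EMBEDDING_DIM = 64\n",
    "HIDDEN_DIM = 128\n",
    "OUTPUT_DIM = 1\n",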
    "\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lv9Ny9di6VcI"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T5t1Afn4xO11"
   },
   "outputs": [],
   "source": [
    "def compute_binary_accuracy(model, data_loader, device):\n",
    "    # Returns the accuracy in percent. The iterators already place each\n",
    "    # batch on DEVICE, so the `device` argument is not used here.\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch_data in enumerate(data_loader):\n",
    "            logits = model(batch_data.text)\n",
    "            # sigmoid(logit) > 0.5 => class label 1, otherwise 0\n",
    "            predicted_labels = (torch.sigmoid(logits) > 0.5).long()\n",
    "            num_examples += batch_data.label.size(0)\n",
    "            correct_pred += (predicted_labels == batch_data.label.long()).sum()\n",
    "        return correct_pred.float()/num_examples * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1836
    },
    "colab_type": "code",
    "id": "EABZM8Vo0ilB",
    "outputId": "ad5a6981-d308-4c2b-ee26-8de50303591d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/015 | Batch 000/157 | Cost: 0.7111\n",
      "Epoch: 001/015 | Batch 050/157 | Cost: 0.6912\n",
      "Epoch: 001/015 | Batch 100/157 | Cost: 0.6856\n",
      "Epoch: 001/015 | Batch 150/157 | Cost: 0.6970\n",
      "training accuracy: 49.94%\n",
      "valid accuracy: 49.96%\n",
      "Time elapsed: 0.42 min\n",
      "Epoch: 002/015 | Batch 000/157 | Cost: 0.6905\n",
      "Epoch: 002/015 | Batch 050/157 | Cost: 0.6980\n",
      "Epoch: 002/015 | Batch 100/157 | Cost: 0.6934\n",
      "Epoch: 002/015 | Batch 150/157 | Cost: 0.6927\n",
      "training accuracy: 49.99%\n",
      "valid accuracy: 49.86%\n",
      "Time elapsed: 0.83 min\n",
      "Epoch: 003/015 | Batch 000/157 | Cost: 0.6947\n",
      "Epoch: 003/015 | Batch 050/157 | Cost: 0.6938\n",
      "Epoch: 003/015 | Batch 100/157 | Cost: 0.7035\n",
      "Epoch: 003/015 | Batch 150/157 | Cost: 0.6942\n",
      "training accuracy: 49.99%\n",
      "valid accuracy: 50.60%\n",
      "Time elapsed: 1.26 min\n",
      "Epoch: 004/015 | Batch 000/157 | Cost: 0.6927\n",
      "Epoch: 004/015 | Batch 050/157 | Cost: 0.6920\n",
      "Epoch: 004/015 | Batch 100/157 | Cost: 0.6916\n",
      "Epoch: 004/015 | Batch 150/157 | Cost: 0.6947\n",
      "training accuracy: 50.07%\n",
      "valid accuracy: 49.80%\n",
      "Time elapsed: 1.68 min\n",
      "Epoch: 005/015 | Batch 000/157 | Cost: 0.6885\n",
      "Epoch: 005/015 | Batch 050/157 | Cost: 0.6907\n",
      "Epoch: 005/015 | Batch 100/157 | Cost: 0.6939\n",
      "Epoch: 005/015 | Batch 150/157 | Cost: 0.6881\n",
      "training accuracy: 50.09%\n",
      "valid accuracy: 49.86%\n",
      "Time elapsed: 2.09 min\n",
      "Epoch: 006/015 | Batch 000/157 | Cost: 0.6939\n",
      "Epoch: 006/015 | Batch 050/157 | Cost: 0.6928\n",
      "Epoch: 006/015 | Batch 100/157 | Cost: 0.6917\n",
      "Epoch: 006/015 | Batch 150/157 | Cost: 0.6915\n",
      "training accuracy: 49.99%\n",
      "valid accuracy: 50.54%\n",
      "Time elapsed: 2.53 min\n",
      "Epoch: 007/015 | Batch 000/157 | Cost: 0.6927\n",
      "Epoch: 007/015 | Batch 050/157 | Cost: 0.6935\n",
      "Epoch: 007/015 | Batch 100/157 | Cost: 0.6931\n",
      "Epoch: 007/015 | Batch 150/157 | Cost: 0.6917\n",
      "training accuracy: 50.05%\n",
      "valid accuracy: 50.18%\n",
      "Time elapsed: 2.95 min\n",
      "Epoch: 008/015 | Batch 000/157 | Cost: 0.6921\n",
      "Epoch: 008/015 | Batch 050/157 | Cost: 0.6940\n",
      "Epoch: 008/015 | Batch 100/157 | Cost: 0.6923\n",
      "Epoch: 008/015 | Batch 150/157 | Cost: 0.6877\n",
      "training accuracy: 50.06%\n",
      "valid accuracy: 49.82%\n",
      "Time elapsed: 3.37 min\n",
      "Epoch: 009/015 | Batch 000/157 | Cost: 0.6926\n",
      "Epoch: 009/015 | Batch 050/157 | Cost: 0.6980\n",
      "Epoch: 009/015 | Batch 100/157 | Cost: 0.6970\n",
      "Epoch: 009/015 | Batch 150/157 | Cost: 0.6900\n",
      "training accuracy: 50.19%\n",
      "valid accuracy: 49.36%\n",
      "Time elapsed: 3.80 min\n",
      "Epoch: 010/015 | Batch 000/157 | Cost: 0.6954\n",
      "Epoch: 010/015 | Batch 050/157 | Cost: 0.6926\n",
      "Epoch: 010/015 | Batch 100/157 | Cost: 0.6916\n",
      "Epoch: 010/015 | Batch 150/157 | Cost: 0.6926\n",
      "training accuracy: 50.01%\n",
      "valid accuracy: 50.16%\n",
      "Time elapsed: 4.22 min\n",
      "Epoch: 011/015 | Batch 000/157 | Cost: 0.6933\n",
      "Epoch: 011/015 | Batch 050/157 | Cost: 0.6933\n",
      "Epoch: 011/015 | Batch 100/157 | Cost: 0.6947\n",
      "Epoch: 011/015 | Batch 150/157 | Cost: 0.6922\n",
      "training accuracy: 50.17%\n",
      "valid accuracy: 49.88%\n",
      "Time elapsed: 4.64 min\n",
      "Epoch: 012/015 | Batch 000/157 | Cost: 0.6927\n",
      "Epoch: 012/015 | Batch 050/157 | Cost: 0.6934\n",
      "Epoch: 012/015 | Batch 100/157 | Cost: 0.6931\n",
      "Epoch: 012/015 | Batch 150/157 | Cost: 0.6934\n",
      "training accuracy: 50.15%\n",
      "valid accuracy: 49.92%\n",
      "Time elapsed: 5.08 min\n",
      "Epoch: 013/015 | Batch 000/157 | Cost: 0.6938\n",
      "Epoch: 013/015 | Batch 050/157 | Cost: 0.6946\n",
      "Epoch: 013/015 | Batch 100/157 | Cost: 0.6956\n",
      "Epoch: 013/015 | Batch 150/157 | Cost: 0.6925\n",
      "training accuracy: 50.10%\n",
      "valid accuracy: 50.38%\n",
      "Time elapsed: 5.51 min\n",
      "Epoch: 014/015 | Batch 000/157 | Cost: 0.6940\n",
      "Epoch: 014/015 | Batch 050/157 | Cost: 0.6917\n",
      "Epoch: 014/015 | Batch 100/157 | Cost: 0.6902\n",
      "Epoch: 014/015 | Batch 150/157 | Cost: 0.6961\n",
      "training accuracy: 50.13%\n",
      "valid accuracy: 50.36%\n",
      "Time elapsed: 5.93 min\n",
      "Epoch: 015/015 | Batch 000/157 | Cost: 0.6985\n",
      "Epoch: 015/015 | Batch 050/157 | Cost: 0.6916\n",
      "Epoch: 015/015 | Batch 100/157 | Cost: 0.6879\n",
      "Epoch: 015/015 | Batch 150/157 | Cost: 0.6934\n",
      "training accuracy: 50.16%\n",
      "valid accuracy: 50.68%\n",
      "Time elapsed: 6.35 min\n",
      "Total Training Time: 6.35 min\n",
      "Test accuracy: 46.38%\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for batch_idx, batch_data in enumerate(train_loader):\n",
    "        \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits = model(batch_data.text)\n",
    "        cost = F.binary_cross_entropy_with_logits(logits, batch_data.label)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n",
    "                   f'Cost: {cost:.4f}')\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print(f'training accuracy: '\n",
    "              f'{compute_binary_accuracy(model, train_loader, DEVICE):.2f}%'\n",
    "              f'\\nvalid accuracy: '\n",
    "              f'{compute_binary_accuracy(model, valid_loader, DEVICE):.2f}%')\n",
    "        \n",
    "    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n",
    "    \n",
    "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n",
    "print(f'Test accuracy: {compute_binary_accuracy(model, test_loader, DEVICE):.2f}%')"
   ]
  },
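  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to read the training log above is to compare the cost with the value a classifier would get if it output a probability of 0.5 for every review. This is a small worked calculation, assuming the loss is averaged over the batch (the default reduction of `F.binary_cross_entropy_with_logits`); the symbols $y \\in \\{0, 1\\}$ for the label, $z$ for the logit, and $\\hat{p} = \\sigma(z)$ for the predicted probability are notation introduced here.\n",
    "\n",
    "$$\\mathcal{L}(y, \\hat{p}) = -\\left[\\, y \\ln \\hat{p} + (1 - y) \\ln (1 - \\hat{p}) \\,\\right]$$\n",
    "\n",
    "$$\\hat{p} = 0.5 \\;\\Rightarrow\\; \\mathcal{L} = -\\ln 0.5 = \\ln 2 \\approx 0.6931$$\n",
    "\n",
    "The logged costs (roughly 0.69 to 0.71) and the training and validation accuracies near 50% are consistent with this chance-level baseline."
   ]
  },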
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0fFkgUdUJOzD"
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en')\n",
    "\n",
    "def predict_sentiment(model, sentence):\n",
    "    # based on:\n",
    "    # https://github.com/bentrevett/pytorch-sentiment-analysis/blob/\n",
    "    # master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    # map tokens to vocabulary indices\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    tensor = torch.LongTensor(indexed).to(DEVICE)\n",
    "    # [sentence len] => [sentence len, 1], i.e., a batch of size 1\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    with torch.no_grad():\n",
    "        prediction = torch.sigmoid(model(tensor))\n",
    "    return prediction.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "WE9axsgOJQaj",
    "outputId": "6e0be8c9-6c47-413f-b6c6-5224fe352816"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability positive:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.5701386332511902"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('Probability positive:')\n",
    "predict_sentiment(model, \"I really love this movie. This movie is so great!\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "rnn_simple_imdb.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
